{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mnist分类任务：\n",
    "\n",
    "- 网络基本构建与训练方法，常用函数解析\n",
    "\n",
    "- torch.nn.functional模块\n",
    "\n",
    "- nn.Module模块\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 读取Mnist数据集\n",
    "- 会自动进行下载"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-10-11T13:54:02.045667Z",
     "start_time": "2025-10-11T13:54:01.792433Z"
    }
   },
   "source": [
    "%matplotlib inline"
   ],
   "outputs": [],
   "execution_count": 2
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-10-11T13:54:03.039595Z",
     "start_time": "2025-10-11T13:54:03.036029Z"
    }
   },
   "source": [
    "from pathlib import Path\n",
    "import requests\n",
    "\n",
    "DATA_PATH = Path(\"data\")\n",
    "PATH = DATA_PATH / \"mnist\"\n",
    "\n",
    "PATH.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# NOTE(review): confirm this host still serves the archive; point URL at a mirror if the download fails.\n",
    "URL = \"http://deeplearning.net/data/mnist/\"\n",
    "FILENAME = \"mnist.pkl.gz\"\n",
    "\n",
    "# Download once; later runs reuse the cached copy on disk.\n",
    "if not (PATH / FILENAME).exists():\n",
    "    response = requests.get(URL + FILENAME, timeout=60)\n",
    "    # Fail loudly on an HTTP error instead of caching an error page as the dataset.\n",
    "    response.raise_for_status()\n",
    "    # write_bytes opens, writes and closes the file, so no handle is left open.\n",
    "    (PATH / FILENAME).write_bytes(response.content)"
   ],
   "outputs": [],
   "execution_count": 3
  },
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-10-11T13:54:04.978646Z",
     "start_time": "2025-10-11T13:54:04.706696Z"
    }
   },
   "source": [
    "import pickle\n",
    "import gzip\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load archives from a trusted source.\n",
    "# gzip.open accepts a path-like object, so the Path is passed directly.\n",
    "with gzip.open(PATH / FILENAME, \"rb\") as f:\n",
    "    # The archive holds three (inputs, labels) splits; the third one is discarded here.\n",
    "    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding=\"latin-1\")"
   ],
   "outputs": [],
   "execution_count": 4
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "784是mnist数据集每个样本的像素点个数（每张图片为28×28像素，展平后得到784维向量）"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-10-11T13:54:21.228870Z",
     "start_time": "2025-10-11T13:54:21.174160Z"
    }
   },
   "source": [
    "from matplotlib import pyplot\n",
    "import numpy as np\n",
    "\n",
    "# Each row holds 784 values; reshape the first one to 28x28 so it can be shown as an image.\n",
    "first_digit = x_train[0].reshape((28, 28))\n",
    "pyplot.imshow(first_digit, cmap=\"gray\")\n",
    "print(x_train.shape)"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(50000, 784)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZrElEQVR4nO3df5BVZf0H8Gf9wYoKSyvCsgIKqFgiOBkQqaSJIJUjSI2azWA5Ohg4KokNTopWtqZpDkXKHw1kKf6YCU2moRRkmRJwQIlxLMZlKDABk9rll4DC+c45zO6XVZDOsstz997Xa+aZy733fPYezp497/uc89znliVJkgQAOMKOOtIvCAApAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEMUxocDs3bs3vPPOO6FTp06hrKws9uoAkFM6v8HWrVtDdXV1OOqoo9pPAKXh06tXr9irAcBhWr9+fejZs2f7OQWX9nwAaP8OdTxvswCaMWNGOO2008Jxxx0Xhg4dGl599dX/qc5pN4DicKjjeZsE0NNPPx0mT54cpk2bFl577bUwaNCgMGrUqPDuu++2xcsB0B4lbWDIkCHJxIkTm+7v2bMnqa6uTmpqag5Z29DQkM7OrWmapoX23dLj+Sdp9R7Q7t27w4oVK8KIESOaHktHQaT3lyxZ8rHld+3aFbZs2dKsAVD8Wj2A3nvvvbBnz57QvXv3Zo+n9zdu3Pix5WtqakJFRUVTMwIOoDREHwU3derU0NDQ0NTSYXsAFL9W/xxQ165dw9FHHx02bdrU7PH0flVV1ceWLy8vzxoApaXVe0AdOnQI5513XliwYEGz2Q3S+8OGDWvtlwOgnWqTmRDSIdjjx48Pn/vc58KQIUPCI488ErZv3x6+9a1vtcXLAdAOtUkAXXXVVeHf//53uPvuu7OBB+eee26YP3/+xwYmAFC6ytKx2KGApMOw09FwALRv6cCyzp07F+4oOABKkwACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKI6J87JQmI4++ujcNRUVFaFQTZo0qUV1xx9/fO6a/v37566ZOHFi7pqf/vSnuWuuueaa0BI7d+7MXXP//ffnrrn33ntDKdIDAiAKAQRAcQTQPffcE8rKypq1s846q7VfBoB2rk2uAZ199tnhpZde+v8XOcalJgCaa5NkSAOnqqqqLX40AEWiTa4BvfXWW6G6ujr07ds3XHvttWHdunUHXXbXrl1hy5YtzRoAxa/VA2jo0KFh9uzZYf78+eHRRx8Na9euDRdeeGHYunXrAZevqanJhrE2tl69erX2KgFQCgE0evTo8PWvfz0MHDgwjBo1KvzhD38I9fX14Zlnnjng8lOnTg0NDQ1Nbf369a29SgAUoDYfHdClS5dw5plnhrq6ugM+X15enjUASkubfw5o27ZtYc2aNaFHjx5t/VIAlHIA3X777aG2tjb84x//CK+88koYO3ZsNr1JS6fCAKA4tfopuLfffjsLm82bN4eTTz45XHDBBWHp0qXZvwGgzQLoqaeeau0fSYHq3bt37poOHTrkrvnCF76QuyZ949PSa5Z5jRs3rkWvVWzSN595TZ8+PXdNelYlr4ONwj2Uv/71r7lr0jNA/G/MBQdAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAoihLkiQJBWTLli3ZV3Nz5Jx77rktqlu4cGHuGr/b9mHv3r25
a7797W+36PvCjoQNGza0qO6///1v7prVq1e36LWKUfot1507dz7o83pAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFMfEeVkKybp161pUt3nz5tw1ZsPeZ9myZblr6uvrc9dcfPHFoSV2796du+Y3v/lNi16L0qUHBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiMBkp4T//+U+L6qZMmZK75qtf/Wrumtdffz13zfTp08ORsnLlytw1l156ae6a7du35645++yzQ0vccsstLaqDPPSAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUZUmSJKGAbNmyJVRUVMReDdpI586dc9ds3bo1d83MmTNDS1x//fW5a775zW/mrpkzZ07uGmhvGhoaPvFvXg8IgCgEEADtI4AWL14cLr/88lBdXR3KysrCc8891+z59Ize3XffHXr06BE6duwYRowYEd56663WXGcASjGA0i/FGjRoUJgxY8YBn3/ggQeyLwN77LHHwrJly8IJJ5wQRo0aFXbu3Nka6wtAqX4j6ujRo7N2IGnv55FHHgnf//73wxVXXJE99vjjj4fu3btnPaWrr7768NcYgKLQqteA1q5dGzZu3JiddmuUjmgbOnRoWLJkyQFrdu3alY18278BUPxaNYDS8EmlPZ79pfcbn/uompqaLKQaW69evVpzlQAoUNFHwU2dOjUbK97Y1q9fH3uVAGhvAVRVVZXdbtq0qdnj6f3G5z6qvLw8+6DS/g2A4teqAdSnT58saBYsWND0WHpNJx0NN2zYsNZ8KQBKbRTctm3bQl1dXbOBBytXrgyVlZWhd+/e4dZbbw0/+tGPwhlnnJEF0l133ZV9ZmjMmDGtve4AlFIALV++PFx88cVN9ydPnpzdjh8/PsyePTvccccd2WeFbrzxxlBfXx8uuOCCMH/+/HDccce17poD0K6ZjJSi9OCDD7aorvENVR61tbW5a/b/qML/au/evblrICaTkQJQkAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIjCbNgUpRNOOKFFdS+88ELumi9+8Yu5a0aPHp275k9/+lPuGojJbNgAFCQBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFGYjBT2069fv9w1r732Wu6a+vr63DUvv/xy7prly5eHlpgxY0bumgI7lFAATEYKQEESQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFyUjhMI0dOzZ3zaxZs3LXdOrUKRwpd955Z+6axx9/PHfNhg0bctfQfpiMFICCJIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCpORQgQDBgzIXfPwww/nrrnkkkvCkTJz5szcNffdd1/umn/961+5a4jDZKQAFCQBBED7CKDFixeHyy+/PFRXV4eysrLw3HPPNXv+uuuuyx7fv1122WWtuc4AlGIAbd++PQwaNCjMmDHjoMukgZN+0VRjmzNnzuGuJwBF5pi8BaNHj87aJykvLw9VVVWHs14AFLk2uQa0aNGi0K1bt9C/f/9w0003hc2bNx902V27dmUj3/ZvABS/Vg+g9PRb+t3wCxYsCD/5yU9CbW1t1mPas2fPAZevqanJhl03tl69erX2KgFQDKfgDuXqq69u+vc555wTBg4cGPr165f1ig70mYSpU6eGyZMnN91Pe0BCCKD4tfkw7L59+4auXbuGurq6g14vSj+otH8DoPi1eQC9/fbb2TWgHj16tPVLAVDMp+C2bdvWrDezdu3asHLlylBZWZm1e++9N4wbNy4bBbdmzZpwxx13
hNNPPz2MGjWqtdcdgFIKoOXLl4eLL7646X7j9Zvx48eHRx99NKxatSr8+te/DvX19dmHVUeOHBl++MMfZqfaAKCRyUihnejSpUvumnTWkpaYNWtW7pp01pO8Fi5cmLvm0ksvzV1DHCYjBaAgCSAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIXZsIGP2bVrV+6aY47J/e0u4cMPP8xd05LvFlu0aFHuGg6f2bABKEgCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKLIP3sgcNgGDhyYu+ZrX/ta7prBgweHlmjJxKIt8eabb+auWbx4cZusC0eeHhAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiMJkpLCf/v37566ZNGlS7porr7wyd01VVVUoZHv27Mlds2HDhtw1e/fuzV1DYdIDAiAKAQRAFAIIgCgEEABRCCAAohBAAEQhgACIQgABEIUAAiAKAQRAFAIIgCgEEABRmIyUgteSSTivueaaFr1WSyYWPe2000KxWb58ee6a++67L3fN73//+9w1FA89IACiEEAAFH4A1dTUhMGDB4dOnTqFbt26hTFjxoTVq1c3W2bnzp1h4sSJ4aSTTgonnnhiGDduXNi0aVNrrzcApRRAtbW1WbgsXbo0vPjii+GDDz4II0eODNu3b29a5rbbbgsvvPBCePbZZ7Pl33nnnRZ9+RYAxS3XIIT58+c3uz979uysJ7RixYowfPjw0NDQEH71q1+FJ598MnzpS1/Klpk1a1b49Kc/nYXW5z//+dZdewBK8xpQGjipysrK7DYNorRXNGLEiKZlzjrrrNC7d++wZMmSA/6MXbt2hS1btjRrABS/FgdQ+r3st956azj//PPDgAEDssc2btwYOnToELp06dJs2e7du2fPHey6UkVFRVPr1atXS1cJgFIIoPRa0BtvvBGeeuqpw1qBqVOnZj2pxrZ+/frD+nkAFPEHUdMP682bNy8sXrw49OzZs9kHBnfv3h3q6+ub9YLSUXAH+zBheXl51gAoLbl6QEmSZOEzd+7csHDhwtCnT59mz5933nnh2GOPDQsWLGh6LB2mvW7dujBs2LDWW2sASqsHlJ52S0e4Pf/889lngRqv66TXbjp27JjdXn/99WHy5MnZwITOnTuHm2++OQsfI+AAaHEAPfroo9ntRRdd1OzxdKj1ddddl/37Zz/7WTjqqKOyD6CmI9xGjRoVfvnLX+Z5GQBKQFmSnlcrIOkw7LQnReFLRzfm9ZnPfCZ3zS9+8YvcNenw/2KzbNmy3DUPPvhgi14rPcvRkpGxsL90YFl6JuxgzAUHQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQAC0n29EpXCl38OU18yZM1v0Wueee27umr59+4Zi88orr+Sueeihh3LX/PGPf8xd8/777+eugSNFDwiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARGEy0iNk6NChuWumTJmSu2bIkCG5a0455ZRQbHbs2NGiuunTp+eu+fGPf5y7Zvv27blroNjoAQEQhQACIAoBBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKExGeoSMHTv2iNQcSW+++Wbumnnz5uWu+fDDD3PXPPTQQ6El6uvrW1QH5KcHBEAUAgiAKAQQAFEIIACiEEAARCGAAIhCAAEQhQACIAoBBEAUAgiAKAQQAFEIIACiKEuSJAkFZMuWLaGioiL2agBwmBoaGkLnzp0P+rweEABRCCAACj+AampqwuDBg0OnTp1Ct27dwpgxY8Lq1aubLXPRRReFsrKyZm3ChAmtvd4AlFIA1dbWhokTJ4alS5eGF198
MXzwwQdh5MiRYfv27c2Wu+GGG8KGDRua2gMPPNDa6w1AKX0j6vz585vdnz17dtYTWrFiRRg+fHjT48cff3yoqqpqvbUEoOgcdbgjHFKVlZXNHn/iiSdC165dw4ABA8LUqVPDjh07Dvozdu3alY18278BUAKSFtqzZ0/yla98JTn//PObPT5z5sxk/vz5yapVq5Lf/va3ySmnnJKMHTv2oD9n2rRp6TBwTdM0LRRXa2ho+MQcaXEATZgwITn11FOT9evXf+JyCxYsyFakrq7ugM/v3LkzW8nGlv682BtN0zRNC20eQLmuATWaNGlSmDdvXli8eHHo2bPnJy47dOjQ7Lauri7069fvY8+Xl5dnDYDSkiuA0h7TzTffHObOnRsWLVoU+vTpc8ialStXZrc9evRo+VoCUNoBlA7BfvLJJ8Pzzz+ffRZo48aN2ePp1DkdO3YMa9asyZ7/8pe/HE466aSwatWqcNttt2Uj5AYOHNhW/wcA2qM8130Odp5v1qxZ2fPr1q1Lhg8fnlRWVibl5eXJ6aefnkyZMuWQ5wH3ly4b+7ylpmmaFg67HerYbzJSANqEyUgBKEgCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgCgEEQBQFF0BJksReBQCOwPG84AJo69atsVcBgCNwPC9LCqzLsXfv3vDOO++ETp06hbKysmbPbdmyJfTq1SusX78+dO7cOZQq22Ef22Ef22Ef26FwtkMaK2n4VFdXh6OOOng/55hQYNKV7dmz5ycuk27UUt7BGtkO+9gO+9gO+9gOhbEdKioqDrlMwZ2CA6A0CCAAomhXAVReXh6mTZuW3ZYy22Ef22Ef22Ef26H9bYeCG4QAQGloVz0gAIqHAAIgCgEEQBQCCIAo2k0AzZgxI5x22mnhuOOOC0OHDg2vvvpqKDX33HNPNjvE/u2ss84KxW7x4sXh8ssvzz5Vnf6fn3vuuWbPp+No7r777tCjR4/QsWPHMGLEiPDWW2+FUtsO11133cf2j8suuywUk5qamjB48OBsppRu3bqFMWPGhNWrVzdbZufOnWHixInhpJNOCieeeGIYN25c2LRpUyi17XDRRRd9bH+YMGFCKCTtIoCefvrpMHny5Gxo4WuvvRYGDRoURo0aFd59991Qas4+++ywYcOGpvbnP/85FLvt27dnv/P0TciBPPDAA2H69OnhscceC8uWLQsnnHBCtn+kB6JS2g6pNHD23z/mzJkTikltbW0WLkuXLg0vvvhi+OCDD8LIkSOzbdPotttuCy+88EJ49tlns+XTqb2uvPLKUGrbIXXDDTc02x/Sv5WCkrQDQ4YMSSZOnNh0f8+ePUl1dXVSU1OTlJJp06YlgwYNSkpZusvOnTu36f7evXuTqqqq5MEHH2x6rL6+PikvL0/mzJmTlMp2SI0fPz654oorklLy7rvvZtuitra26Xd/7LHHJs8++2zTMn/729+yZZYsWZKUynZIffGLX0xuueWWpJAVfA9o9+7dYcWKFdlplf3ni0vvL1myJJSa9NRSegqmb9++4dprrw3r1q0LpWzt2rVh48aNzfaPdA6q9DRtKe4fixYtyk7J9O/fP9x0001h8+bNoZg1NDRkt5WVldlteqxIewP77w/paerevXsX9f7Q8JHt0OiJJ54IXbt2DQMGDAhTp04NO3bsCIWk4CYj/aj33nsv7NmzJ3Tv3r3Z4+n9v//976GUpAfV2bNnZweXtDt97733hgsvvDC88cYb2bngUpSGT+pA+0fjc6UiPf2Wnmrq06dPWLNmTbjzzjvD6NGjswPv0UcfHYpNOnP+rbfeGs4///zsAJtKf+cdOnQIXbp0KZn9Ye8BtkPqG9/4Rjj11FOzN6yrVq0K3/ve97LrRL/73e9CoSj4AOL/pQeTRgMHDswC
Kd3BnnnmmXD99ddHXTfiu/rqq5v+fc4552T7SL9+/bJe0SWXXBKKTXoNJH3zVQrXQVuyHW688cZm+0M6SCfdD9I3J+l+UQgK/hRc2n1M3719dBRLer+qqiqUsvRd3plnnhnq6upCqWrcB+wfH5eepk3/fopx/5g0aVKYN29eePnll5t9fUv6O09P29fX15fE/jDpINvhQNI3rKlC2h8KPoDS7vR5550XFixY0KzLmd4fNmxYKGXbtm3L3s2k72xKVXq6KT2w7L9/pF/IlY6GK/X94+23386uARXT/pGOv0gPunPnzg0LFy7Mfv/7S48Vxx57bLP9IT3tlF4rLab9ITnEdjiQlStXZrcFtT8k7cBTTz2VjWqaPXt28uabbyY33nhj0qVLl2Tjxo1JKfnud7+bLFq0KFm7dm3yl7/8JRkxYkTStWvXbARMMdu6dWvy+uuvZy3dZR9++OHs3//85z+z5++///5sf3j++eeTVatWZSPB+vTpk7z//vtJqWyH9Lnbb789G+mV7h8vvfRS8tnPfjY544wzkp07dybF4qabbkoqKiqyv4MNGzY0tR07djQtM2HChKR3797JwoULk+XLlyfDhg3LWjG56RDboa6uLvnBD36Q/f/T/SH92+jbt28yfPjwpJC0iwBK/fznP892qg4dOmTDspcuXZqUmquuuirp0aNHtg1OOeWU7H66oxW7l19+OTvgfrSlw44bh2LfddddSffu3bM3KpdcckmyevXqpJS2Q3rgGTlyZHLyySdnw5BPPfXU5IYbbii6N2kH+v+nbdasWU3LpG88vvOd7ySf+tSnkuOPPz4ZO3ZsdnAupe2wbt26LGwqKyuzv4nTTz89mTJlStLQ0JAUEl/HAEAUBX8NCIDiJIAAiEIAARCFAAIgCgEEQBQCCIAoBBAAUQggAKIQQABEIYAAiEIAARCFAAIgxPB/C6jvgw08HNgAAAAASUVORK5CYII="
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./img/4.png\" alt=\"FAO\" width=\"790\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./img/5.png\" alt=\"FAO\" width=\"790\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意数据需转换成tensor才能参与后续建模训练\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        ...,\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
      "        [0., 0., 0.,  ..., 0., 0., 0.]]) tensor([5, 0, 4,  ..., 8, 4, 8])\n",
      "torch.Size([50000, 784])\n",
      "tensor(0) tensor(9)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# as_tensor converts the NumPy arrays and returns existing tensors unchanged,\n",
    "# so re-running this cell neither re-copies the data nor triggers a copy-construct warning.\n",
    "x_train, y_train, x_valid, y_valid = map(\n",
    "    torch.as_tensor, (x_train, y_train, x_valid, y_valid)\n",
    ")\n",
    "# n = number of rows, c = number of columns of the training matrix.\n",
    "n, c = x_train.shape\n",
    "print(x_train, y_train)\n",
    "print(x_train.shape)\n",
    "print(y_train.min(), y_train.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### torch.nn.functional 很多层和函数在这里都会见到\n",
    "\n",
    "torch.nn.functional中有很多功能，后续会常用的。那什么时候使用nn.Module，什么时候使用nn.functional呢？一般情况下，如果模型有可学习的参数，最好用nn.Module，其他情况nn.functional相对更简单一些"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# Loss used throughout the notebook.\n",
    "loss_func = F.cross_entropy\n",
    "\n",
    "def model(xb):\n",
    "    \"\"\"Return xb.mm(weights) + bias, reading the notebook-level `weights` and `bias`.\"\"\"\n",
    "    logits = xb.mm(weights)\n",
    "    return logits + bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(10.7988, grad_fn=<NllLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "bs = 64\n",
    "xb = x_train[0:bs]  # a mini-batch from x\n",
    "yb = y_train[0:bs]\n",
    "# Parameters of the hand-written linear model; both are tracked by autograd.\n",
    "weights = torch.randn([784, 10], dtype=torch.float, requires_grad=True)\n",
    "bias = torch.zeros(10, requires_grad=True)\n",
    "\n",
    "# Loss of the untrained model on the first mini-batch.\n",
    "print(loss_func(model(xb), yb))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 创建一个model来更简化代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数\n",
    "- 无需写反向传播函数，nn.Module能够利用autograd自动实现反向传播\n",
    "- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "class Mnist_NN(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = nn.Linear(784, 128)\n",
    "        self.hidden2 = nn.Linear(128, 256)\n",
    "        self.out  = nn.Linear(256, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.hidden1(x))\n",
    "        x = F.relu(self.hidden2(x))\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mnist_NN(\n",
      "  (hidden1): Linear(in_features=784, out_features=128, bias=True)\n",
      "  (hidden2): Linear(in_features=128, out_features=256, bias=True)\n",
      "  (out): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Instantiate the network; printing a Module lists its registered sub-layers.\n",
    "net = Mnist_NN()\n",
    "print(net)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以按名字打印出我们定义的各层的权重和偏置项"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "hidden1.weight Parameter containing:\n",
      "tensor([[ 0.0018,  0.0218,  0.0036,  ..., -0.0286, -0.0166,  0.0089],\n",
      "        [-0.0349,  0.0268,  0.0328,  ...,  0.0263,  0.0200, -0.0137],\n",
      "        [ 0.0061,  0.0060, -0.0351,  ...,  0.0130, -0.0085,  0.0073],\n",
      "        ...,\n",
      "        [-0.0231,  0.0195, -0.0205,  ..., -0.0207, -0.0103, -0.0223],\n",
      "        [-0.0299,  0.0305,  0.0098,  ...,  0.0184, -0.0247, -0.0207],\n",
      "        [-0.0306, -0.0252, -0.0341,  ...,  0.0136, -0.0285,  0.0057]],\n",
      "       requires_grad=True) torch.Size([128, 784])\n",
      "hidden1.bias Parameter containing:\n",
      "tensor([ 0.0072, -0.0269, -0.0320, -0.0162,  0.0102,  0.0189, -0.0118, -0.0063,\n",
      "        -0.0277,  0.0349,  0.0267, -0.0035,  0.0127, -0.0152, -0.0070,  0.0228,\n",
      "        -0.0029,  0.0049,  0.0072,  0.0002, -0.0356,  0.0097, -0.0003, -0.0223,\n",
      "        -0.0028, -0.0120, -0.0060, -0.0063,  0.0237,  0.0142,  0.0044, -0.0005,\n",
      "         0.0349, -0.0132,  0.0138, -0.0295, -0.0299,  0.0074,  0.0231,  0.0292,\n",
      "        -0.0178,  0.0046,  0.0043, -0.0195,  0.0175, -0.0069,  0.0228,  0.0169,\n",
      "         0.0339,  0.0245, -0.0326, -0.0260, -0.0029,  0.0028,  0.0322, -0.0209,\n",
      "        -0.0287,  0.0195,  0.0188,  0.0261,  0.0148, -0.0195, -0.0094, -0.0294,\n",
      "        -0.0209, -0.0142,  0.0131,  0.0273,  0.0017,  0.0219,  0.0187,  0.0161,\n",
      "         0.0203,  0.0332,  0.0225,  0.0154,  0.0169, -0.0346, -0.0114,  0.0277,\n",
      "         0.0292, -0.0164,  0.0001, -0.0299, -0.0076, -0.0128, -0.0076, -0.0080,\n",
      "        -0.0209, -0.0194, -0.0143,  0.0292, -0.0316, -0.0188, -0.0052,  0.0013,\n",
      "        -0.0247,  0.0352, -0.0253, -0.0306,  0.0035, -0.0253,  0.0167, -0.0260,\n",
      "        -0.0179, -0.0342,  0.0033, -0.0287, -0.0272,  0.0238,  0.0323,  0.0108,\n",
      "         0.0097,  0.0219,  0.0111,  0.0208, -0.0279,  0.0324, -0.0325, -0.0166,\n",
      "        -0.0010, -0.0007,  0.0298,  0.0329,  0.0012, -0.0073, -0.0010,  0.0057],\n",
      "       requires_grad=True) torch.Size([128])\n",
      "hidden2.weight Parameter containing:\n",
      "tensor([[-0.0383, -0.0649,  0.0665,  ..., -0.0312,  0.0394, -0.0801],\n",
      "        [-0.0189, -0.0342,  0.0431,  ..., -0.0321,  0.0072,  0.0367],\n",
      "        [ 0.0289,  0.0780,  0.0496,  ...,  0.0018, -0.0604, -0.0156],\n",
      "        ...,\n",
      "        [-0.0360,  0.0394, -0.0615,  ...,  0.0233, -0.0536, -0.0266],\n",
      "        [ 0.0416,  0.0082, -0.0345,  ...,  0.0808, -0.0308, -0.0403],\n",
      "        [-0.0477,  0.0136, -0.0408,  ...,  0.0180, -0.0316, -0.0782]],\n",
      "       requires_grad=True) torch.Size([256, 128])\n",
      "hidden2.bias Parameter containing:\n",
      "tensor([-0.0694, -0.0363, -0.0178,  0.0206, -0.0875, -0.0876, -0.0369, -0.0386,\n",
      "         0.0642, -0.0738, -0.0017, -0.0243, -0.0054,  0.0757, -0.0254,  0.0050,\n",
      "         0.0519, -0.0695,  0.0318, -0.0042, -0.0189, -0.0263, -0.0627, -0.0691,\n",
      "         0.0713, -0.0696, -0.0672,  0.0297,  0.0102,  0.0040,  0.0830,  0.0214,\n",
      "         0.0714,  0.0327, -0.0582, -0.0354,  0.0621,  0.0475,  0.0490,  0.0331,\n",
      "        -0.0111, -0.0469, -0.0695, -0.0062, -0.0432, -0.0132, -0.0856, -0.0219,\n",
      "        -0.0185, -0.0517,  0.0017, -0.0788, -0.0403,  0.0039,  0.0544, -0.0496,\n",
      "         0.0588, -0.0068,  0.0496,  0.0588, -0.0100,  0.0731,  0.0071, -0.0155,\n",
      "        -0.0872, -0.0504,  0.0499,  0.0628, -0.0057,  0.0530, -0.0518, -0.0049,\n",
      "         0.0767,  0.0743,  0.0748, -0.0438,  0.0235, -0.0809,  0.0140, -0.0374,\n",
      "         0.0615, -0.0177,  0.0061, -0.0013, -0.0138, -0.0750, -0.0550,  0.0732,\n",
      "         0.0050,  0.0778,  0.0415,  0.0487,  0.0522,  0.0867, -0.0255, -0.0264,\n",
      "         0.0829,  0.0599,  0.0194,  0.0831, -0.0562,  0.0487, -0.0411,  0.0237,\n",
      "         0.0347, -0.0194, -0.0560, -0.0562, -0.0076,  0.0459, -0.0477,  0.0345,\n",
      "        -0.0575, -0.0005,  0.0174,  0.0855, -0.0257, -0.0279, -0.0348, -0.0114,\n",
      "        -0.0823, -0.0075, -0.0524,  0.0331,  0.0387, -0.0575,  0.0068, -0.0590,\n",
      "        -0.0101, -0.0880, -0.0375,  0.0033, -0.0172, -0.0641, -0.0797,  0.0407,\n",
      "         0.0741, -0.0041, -0.0608,  0.0672, -0.0464, -0.0716, -0.0191, -0.0645,\n",
      "         0.0397,  0.0013,  0.0063,  0.0370,  0.0475, -0.0535,  0.0721, -0.0431,\n",
      "         0.0053, -0.0568, -0.0228, -0.0260, -0.0784, -0.0148,  0.0229, -0.0095,\n",
      "        -0.0040,  0.0025,  0.0781,  0.0140, -0.0561,  0.0384, -0.0011, -0.0366,\n",
      "         0.0345,  0.0015,  0.0294, -0.0734, -0.0852, -0.0015, -0.0747, -0.0100,\n",
      "         0.0801, -0.0739,  0.0611,  0.0536,  0.0298, -0.0097,  0.0017, -0.0398,\n",
      "         0.0076, -0.0759, -0.0293,  0.0344, -0.0463, -0.0270,  0.0447,  0.0814,\n",
      "        -0.0193, -0.0559,  0.0160,  0.0216, -0.0346,  0.0316,  0.0881, -0.0652,\n",
      "        -0.0169,  0.0117, -0.0107, -0.0754, -0.0231, -0.0291,  0.0210,  0.0427,\n",
      "         0.0418,  0.0040,  0.0762,  0.0645, -0.0368, -0.0229, -0.0569, -0.0881,\n",
      "        -0.0660,  0.0297,  0.0433, -0.0777,  0.0212, -0.0601,  0.0795, -0.0511,\n",
      "        -0.0634,  0.0720,  0.0016,  0.0693, -0.0547, -0.0652, -0.0480,  0.0759,\n",
      "         0.0194, -0.0328, -0.0211, -0.0025, -0.0055, -0.0157,  0.0817,  0.0030,\n",
      "         0.0310, -0.0735,  0.0160, -0.0368,  0.0528, -0.0675, -0.0083, -0.0427,\n",
      "        -0.0872,  0.0699,  0.0795, -0.0738, -0.0639,  0.0350,  0.0114,  0.0303],\n",
      "       requires_grad=True) torch.Size([256])\n",
      "out.weight Parameter containing:\n",
      "tensor([[ 0.0232, -0.0571,  0.0439,  ..., -0.0417, -0.0237,  0.0183],\n",
      "        [ 0.0210,  0.0607,  0.0277,  ..., -0.0015,  0.0571,  0.0502],\n",
      "        [ 0.0297, -0.0393,  0.0616,  ...,  0.0131, -0.0163, -0.0239],\n",
      "        ...,\n",
      "        [ 0.0416,  0.0309, -0.0441,  ..., -0.0493,  0.0284, -0.0230],\n",
      "        [ 0.0404, -0.0564,  0.0442,  ..., -0.0271, -0.0526, -0.0554],\n",
      "        [-0.0404, -0.0049, -0.0256,  ..., -0.0262, -0.0130,  0.0057]],\n",
      "       requires_grad=True) torch.Size([10, 256])\n",
      "out.bias Parameter containing:\n",
      "tensor([-0.0536,  0.0007,  0.0227, -0.0072, -0.0168, -0.0125, -0.0207, -0.0558,\n",
      "         0.0579, -0.0439], requires_grad=True) torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "# named_parameters() yields (name, Parameter) pairs; print each tensor with its size.\n",
    "for name, parameter in net.named_parameters():\n",
    "    print(name, parameter,parameter.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用TensorDataset和DataLoader来简化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Wrap inputs and labels together so indexing a dataset yields an (x, y) pair.\n",
    "train_ds = TensorDataset(x_train, y_train)\n",
    "valid_ds = TensorDataset(x_valid, y_valid)\n",
    "\n",
    "# Only the training loader shuffles; the validation loader uses a doubled batch size.\n",
    "train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)\n",
    "valid_dl = DataLoader(valid_ds, batch_size=bs * 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_data(train_ds, valid_ds, bs):\n",
    "    return (\n",
    "        DataLoader(train_ds, batch_size=bs, shuffle=True),\n",
    "        DataLoader(valid_ds, batch_size=bs * 2),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 一般在训练模型时加上model.train()，这样会正常使用Batch Normalization和 Dropout\n",
    "- 测试的时候一般选择model.eval()，这样Dropout会被关闭，Batch Normalization会改用训练过程中累计的均值和方差（而不是当前batch的统计量）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def fit(steps, model, loss_func, opt, train_dl, valid_dl):\n",
    "    \"\"\"Train `model` for `steps` passes over `train_dl`, printing the validation loss after each pass.\n",
    "\n",
    "    Uses the notebook-level `loss_batch` helper for both the training update\n",
    "    and the per-batch validation loss.\n",
    "    \"\"\"\n",
    "    for step in range(steps):\n",
    "        # Training pass: loss_batch receives the optimizer for every batch.\n",
    "        model.train()\n",
    "        for xb, yb in train_dl:\n",
    "            loss_batch(model, loss_func, xb, yb, opt)\n",
    "\n",
    "        # Validation pass: no optimizer and no gradient tracking.\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            batch_results = [\n",
    "                loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl\n",
    "            ]\n",
    "        losses, nums = zip(*batch_results)\n",
    "        # Weight each batch loss by its batch size before averaging.\n",
    "        weighted_total = np.sum(np.multiply(losses, nums))\n",
    "        val_loss = weighted_total / np.sum(nums)\n",
    "        print('当前step:'+str(step), '验证集损失：'+str(val_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "def get_model(lr=0.001):\n",
    "    \"\"\"Create a fresh Mnist_NN together with an SGD optimizer over its parameters.\n",
    "\n",
    "    Args:\n",
    "        lr: Learning rate passed to optim.SGD. Defaults to 0.001.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (model, optimizer).\n",
    "    \"\"\"\n",
    "    model = Mnist_NN()\n",
    "    return model, optim.SGD(model.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def loss_batch(model, loss_func, xb, yb, opt=None):\n",
    "    loss = loss_func(model(xb), yb)\n",
    "\n",
    "    if opt is not None:\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "    return loss.item(), len(xb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三行搞定！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前step:0 验证集损失：2.2796445930480957\n",
      "当前step:1 验证集损失：2.2440698066711424\n",
      "当前step:2 验证集损失：2.1889826164245605\n",
      "当前step:3 验证集损失：2.0985311767578123\n",
      "当前step:4 验证集损失：1.9517273582458496\n",
      "当前step:5 验证集损失：1.7341805934906005\n",
      "当前step:6 验证集损失：1.4719875366210937\n",
      "当前step:7 验证集损失：1.2273896869659424\n",
      "当前step:8 验证集损失：1.0362271406173706\n",
      "当前step:9 验证集损失：0.8963696184158325\n",
      "当前step:10 验证集损失：0.7927186088562012\n",
      "当前step:11 验证集损失：0.7141492074012756\n",
      "当前step:12 验证集损失：0.6529350900650024\n",
      "当前step:13 验证集损失：0.60417300491333\n",
      "当前step:14 验证集损失：0.5643046331882476\n",
      "当前step:15 验证集损失：0.5317994566917419\n",
      "当前step:16 验证集损失：0.5047958114624024\n",
      "当前step:17 验证集损失：0.4813900615692139\n",
      "当前step:18 验证集损失：0.4618900228500366\n",
      "当前step:19 验证集损失：0.4443243554592133\n",
      "当前step:20 验证集损失：0.4297310716629028\n",
      "当前step:21 验证集损失：0.416976597738266\n",
      "当前step:22 验证集损失：0.406348459148407\n",
      "当前step:23 验证集损失：0.3963301926612854\n",
      "当前step:24 验证集损失：0.38733808159828187\n"
     ]
    }
   ],
   "source": [
    "# Build the loaders, create a fresh model/optimizer pair, then train for 25 passes over the data.\n",
    "# NOTE(review): `model` here rebinds the name of the earlier hand-written `model(xb)` function.\n",
    "train_dl, valid_dl = get_data(train_ds, valid_ds, bs)\n",
    "model, opt = get_model()\n",
    "fit(25, model, loss_func, opt, train_dl, valid_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
